如何计算 Petaflops(下面的脚本统计 SavedModel 计算图的浮点运算总数与可训练参数量,并以 GFLOPs 为单位写入 log.txt)

import tensorflow as tf

# Rebind `tf` to the TF1 compatibility namespace: the rest of this script
# uses the TF1-style tf.Session, tf.profiler and tf.get_default_graph APIs.
tf = tf.compat.v1

  

def stats_graph(graph, log_path='log.txt'):
    """Profile ``graph`` and log its float-op and trainable-parameter totals.

    The TensorFlow profiler is run twice over the graph: once with the
    float-operation options and once with the trainable-variables-parameter
    options. A single summary line is then written to ``log_path``,
    overwriting any existing file at that path.

    Args:
        graph: The graph to profile; passed unchanged to
            ``tf.profiler.profile``.
        log_path: File that receives the summary line. Defaults to
            ``'log.txt'``, the path that was previously hard-coded.

    Returns:
        A ``(gflops, trainable_params)`` tuple with the same two values
        that are written to the log file.
    """
    option_builder = tf.profiler.ProfileOptionBuilder
    flops = tf.profiler.profile(
        graph, options=option_builder.float_operation())
    params = tf.profiler.profile(
        graph, options=option_builder.trainable_variables_parameter())

    # The profiler reports a raw operation count; scale it to giga-ops.
    flops_per_gflop = 1000000000.0
    gflops = flops.total_float_ops / flops_per_gflop
    trainable_params = params.total_parameters

    # A context manager guarantees the handle is closed even if formatting
    # or writing the summary line raises.
    with open(log_path, 'w') as f:
        print('GFLOPs: {}; Trainable params: {}'.format(gflops, trainable_params), file=f)

    return gflops, trainable_params

# Directory containing the exported SavedModel to be profiled.
input_saved_model_dir = "./1719407478/"

# Load the model into a fresh, empty graph so that only the SavedModel's
# own nodes are counted by the profiler.
with tf.Session(graph=tf.Graph()) as sess:
    tf.saved_model.loader.load(
        sess=sess,
        tags=["serve"],
        export_dir=input_saved_model_dir,
    )
    # Inside this `with` block the session's graph is the default graph,
    # so it is the graph the model was just loaded into.
    graph = sess.graph
    stats_graph(graph)